from runmodel import *
from clean_doc import *
import csv
import datetime
#from RunSaliency import*
from keras import backend as k
import matplotlib.pyplot as plt
def add_csv(row, path='8_21_BOW_FoldWithLoss.csv'):
    """Append a single record to the results CSV file.

    Args:
        row: Sequence of values written as one CSV record. Non-string values
            are stringified by ``csv.writer``.
        path: File to append to. Defaults to the results file this script has
            always written; opening in append mode creates it if missing.
    """
    # newline='' is required when handing a file object to csv.writer:
    # otherwise the writer's '\r\n' terminator is translated again on
    # platforms like Windows, leaving a blank line between records.
    with open(path, 'a', newline='') as csv_file:
        csv.writer(csv_file).writerow(row)

#
# We do not provide our corpus because it contains patient information.
# Two variables, `docs` and `cat`, hold the text corpus and the label for each
# patient; these must be provided by the user.
# The variable `ds` holds the name of the corpus used.
# Model identifiers passed to CNN_call (their meaning is defined in runmodel).
models = [0, 1, 2, 3, 5, 6, 7, 8]
N_FOLDS = 10
RUNS_PER_FOLD = 10


def _save_roc_plot(fpr, tpr, auc, acc, path):
    """Draw one ROC curve plus the chance diagonal and save it to *path*."""
    # Used this link for help with ROC curve
    # https://stackoverflow.com/questions/25009284/how-to-plot-roc-curve-in-python
    plt.title('ROC')
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f, for accuracy %.2f' % (auc, acc))
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.savefig(path)
    # Clear the figure so the next saved plot does not overlay this curve.
    plt.clf()


# For every model and fold, train RUNS_PER_FOLD times, log every run to the
# CSV, and keep the ROC plot / saved model of the most accurate run.
for mod_num in models:
    for fold_num in range(N_FOLDS):
        # Start below any real accuracy so the first run of each fold is always
        # saved. With 0, a fold whose runs all scored 0.0 saved nothing and
        # left best_* holding the previous fold's values.
        top_acc = float('-inf')
        for n in range(RUNS_PER_FOLD):
            (model, loss, acc, fpr, tpr, auc, test_preds, train_doc, test_doc,
             train_val, test_val, train_encoded_docs, test_encoded_docs,
             val_loss, train_loss) = CNN_call(docs, cat, mod_num, fold_num)
            now = datetime.datetime.now()
            row = [mod_num, fold_num, acc, loss, auc, fpr, tpr,
                   train_loss, val_loss, ds, str(now)]
            add_csv(row)

            if acc > top_acc:
                top_acc = acc
                stem = 'bestAcc_ROC_BOW_model%s_fold%s' % (mod_num, fold_num)
                nm1 = stem + '.png'
                _save_roc_plot(fpr, tpr, auc, acc, nm1)
                nm2 = stem + '.h5'
                model.save(nm2)
                # Kept for the optional saliency step below.
                best_test_encoded_docs = test_encoded_docs
                best_test_doc = test_doc
                best_test_val = test_val
        # To run saliency without a memory leak, change the loops to process
        # only one fold at a time, then enable:
        #saliency(best_test_encoded_docs,best_test_doc,best_test_val,nm2,mod_num,fold_num)
        # Release the Keras graph/session state accumulated over this fold's runs.
        k.clear_session()